# -*- coding: utf-8 -*-
"""
Created on Thu Nov 28 17:00:16 2019
"""
import json
import threading
from collections.abc import Iterable
from concurrent.futures import ThreadPoolExecutor
from queue import Queue, Empty
from urllib.parse import urlencode
from urllib.request import Request, urlopen

import dateutil.parser
import pandas as pd

from util.dfcf import settings
class EastmoneyExporter:
    '''
    An abstract base class for Eastmoney exporters.

    It provides helpers to build the ``secid`` eastmoney expects, to build
    a GET request url, and to fetch a url and turn the response text into
    a python object via ``self.data_loader`` (the identity function here;
    subclasses override it).
    '''
    def data_loader(self, data):
        '''
        data loader is used when parsing the content fetched from the web
        to a python object. The base implementation returns ``data``
        unchanged; subclasses override it (e.g. with a JSON decoder).

        Parameters
        ----------
        data : str
            data to be parsed.
        '''
        return data

    @staticmethod
    def get_sec_id(code, asset_type):
        '''
        Build the ``secid`` used by eastmoney, i.e. ``'<prefix>.<code>'``.

        eastmoney uses a category number (the prefix) to distinguish an
        index code from an individual stock code.

        Parameters
        ----------
        code : str
            the code of sec/index to be exported. Must be non-empty.
        asset_type : str
            if asset_type is 'I' (index), the prefix is '1'.
            if asset_type is 'E' (equity), the prefix is '1' for codes
            starting with '6' or '9' (Shanghai) and '0' for codes
            starting with '0' or '3' (Shenzhen).

        Returns
        -------
        str
            e.g. ``'1.600000'`` or ``'0.000001'``.

        Raises
        ------
        ValueError
            if ``code`` is empty, an equity code has an unrecognised first
            digit, or ``asset_type`` is neither 'I' nor 'E'.
        '''
        if not code:
            # an empty code would otherwise raise an opaque IndexError
            # (equity) or silently produce the meaningless secid '1.' (index).
            raise ValueError('code must be a non-empty string')
        if asset_type == 'I':
            prefix = '1'
        elif asset_type == 'E':
            if code[0] in ('6', '9'):
                prefix = '1'
            elif code[0] in ('0', '3'):
                prefix = '0'
            else:
                raise ValueError(f'unknown code {code}')
        else:
            raise ValueError(f'unknown asset_type {asset_type}')
        return f'{prefix}.{code}'

    @staticmethod
    def get_request_url(url, data):
        '''
        Append ``data`` to ``url`` as an url-encoded query string.

        Parameters
        ----------
        url : str
            base url. If it already carries a query string the new
            parameters are appended with '&'.
        data : dict
            parameters encoded with ``urllib.parse.urlencode``.

        Returns
        -------
        str
            the full request url; ``url`` itself when ``data`` is empty.
        '''
        query = urlencode(data)
        if not query:
            return url
        if url.endswith(('?', '&')):
            separator = ''
        elif '?' in url:
            separator = '&'
        else:
            separator = '?'
        return f'{url}{separator}{query}'

    def get_parsed_data(self, url, data):
        '''
        read data string from web and then transform it to python object.
        the transformation function is self.data_loader.

        Parameters
        ----------
        url : str
            the url used to fetch the data.
        data : dict
            the data used in GET method to construct url.

        Returns
        -------
        whatever ``self.data_loader`` returns for the decoded response text.
        '''
        with urlopen(Request(self.get_request_url(url, data))) as response:
            # honour the charset declared by the server; fall back to UTF-8
            # (the previous hard-coded choice) when none is declared.
            charset = response.headers.get_content_charset() or 'utf-8'
            text = response.read().decode(charset)
        return self.data_loader(text)

class EastmoneyRealtime(EastmoneyExporter):
    '''
    An exporter used for exporting realtime (intraday trend) data from
    eastmoney.
    '''
    def __init__(self, field_code_mapping=None, field_parser_mapping=None):
        '''
        init the exporter.
        self.field_code_mapping starts as a copy of
        settings.DEFAULT_RT_CODE_MAPPING and self.field_parser_mapping as a
        copy of settings.DEFAULT_RT_PARSER_MAPPING. If field_code_mapping is
        specified it is used to update self.field_code_mapping; likewise
        field_parser_mapping updates self.field_parser_mapping.

        Parameters
        ----------
        field_code_mapping : dict, optional
            a mapping from human-readable field name to the field code used
            by eastmoney (e.g. 'f51').
        field_parser_mapping : dict, optional
            a mapping from field code to a parser used to convert an
            exported string to a proper python value.
        '''
        # copy the defaults so per-instance updates never mutate the shared
        # dictionaries in settings.
        self.field_code_mapping = settings.DEFAULT_RT_CODE_MAPPING.copy()
        self.field_parser_mapping = settings.DEFAULT_RT_PARSER_MAPPING.copy()
        if field_code_mapping is not None:
            self.field_code_mapping.update(field_code_mapping)
        if field_parser_mapping is not None:
            # previously this was updated with field_code_mapping by mistake,
            # dropping the user's parsers and inserting name->code strings.
            self.field_parser_mapping.update(field_parser_mapping)

    def data_loader(self, data):
        '''
        Decode the JSON text returned by eastmoney into python objects.

        Parameters
        ----------
        data : str
            the raw response text.
        '''
        return json.loads(data)

    def get_field_names(self, fields=None):
        '''
        get the field names specified in fields, ordered by the numeric part
        of their field code ('f51' -> 51). if fields is None or empty,
        return all the available field names in self.field_code_mapping.

        Parameters
        ----------
        fields : None, str or Iterable
            the fields to be exported, each either a field name or a field
            code. A code is resolved to the first field name mapped to it.
            Unknown fields are ignored.
        '''
        mapping = self.field_code_mapping
        if not fields:
            wanted = set(mapping)
        else:
            if isinstance(fields, str):
                # a bare string would otherwise be matched character-wise
                fields = [fields]
            code_to_name = {}
            for field_name, field_code in mapping.items():
                code_to_name.setdefault(field_code, field_name)
            wanted = set()
            for field in fields:
                if field in mapping:
                    wanted.add(field)
                elif field in code_to_name:
                    wanted.add(code_to_name[field])
        # sorted() is stable, so names sharing a code keep mapping order.
        return sorted(
            (name for name in mapping if name in wanted),
            key=lambda name: int(mapping[name][1:]),
        )

    def get_field_code(self, field):
        '''
        get the field_code in accordance with self.field_code_mapping.
        note the field must be either in field_code_mapping.keys() or
        field_code_mapping.values(). otherwise, ValueError will be raised.

        Parameters
        ----------
        field : human-readable field name (or a field code) to be converted
            to the field code.
        '''
        if field in self.field_code_mapping:
            field_code = self.field_code_mapping[field]
        elif field in self.field_code_mapping.values():
            field_code = field
        else:
            raise ValueError(f'field {field} is not a valid field.')
        return field_code

    def get_field_parser(self, field_code):
        '''
        get the field parser used when parsing the acquired column specified
        by field_code. Falls back to settings.DEFAULT_FIELD_PARSER when no
        (truthy) parser is registered for the code.

        Parameters
        ----------
        field_code : code used by eastmoney.
        '''
        return self.field_parser_mapping.get(field_code) or settings.DEFAULT_FIELD_PARSER

    def get_parsed_data(self, code, field_codes, asset_type):
        '''
        fetch the trend data of one sec/index and return the decoded JSON.

        Parameters
        ----------
        code : str
            code of sec/index to be exported.
        field_codes : str
            comma-separated field codes used by eastmoney to fetch data.
        asset_type : str
            the same as asset_type in self.get_sec_id.
        '''
        url = 'http://push2his.eastmoney.com/api/qt/stock/trends2/get'
        req_data = {
            # eastmoney identifies a security by its secid '<prefix>.<code>';
            # get_sec_id uses '1' for indices and Shanghai stocks and '0'
            # for Shenzhen stocks.
            'secid': self.get_sec_id(code, asset_type),
            'fields1': 'f1,f2,f3,f4,f5,f6,f7,f8,f9,f10,f11',
            # fields2 selects the per-record data columns. Known codes: f51,
            # f53, f56, f58 are trade time, current price, volume, and the
            # day's average price (stocks) / leading index (indices).
            'fields2': field_codes,
            # NOTE(review): iscr / ndays are not documented here; ndays
            # presumably is the number of trading days returned.
            'iscr': 0,
            'ndays': 1,
        }
        return super().get_parsed_data(url, req_data)

    def get_sec_rt_quote(self, code, fields=None, asset_type='I'):
        '''
        get realtime data (a DataFrame) from eastmoney.

        Parameters
        ----------
        code : str
            code of sec/index to be exported.
        fields : Sequence
            a sequence of fields (either field name or field code) to be
            exported. None exports every known field.
        asset_type : str
            the same as asset_type in self.get_sec_id.

        Returns
        -------
        pandas.DataFrame
            one row per trend record, one column per selected field name.

        Raises
        ------
        ValueError
            if the code/asset_type is invalid or the response carries no
            'data' payload.
        '''
        field_names = self.get_field_names(fields)
        field_codes = [self.get_field_code(name) for name in field_names]
        # parsers are aligned with the columns of each comma-separated record
        field_parsers = [self.get_field_parser(fc) for fc in field_codes]
        parsed_data = self.get_parsed_data(code, ','.join(field_codes), asset_type)

        payload = parsed_data.get('data')
        if payload is None:
            # NOTE(review): assumed to happen for codes eastmoney does not
            # know; previously this surfaced as an opaque TypeError.
            raise ValueError(
                f'eastmoney returned no data for {code!r} (asset_type={asset_type!r})')
        # treat a missing/null trends list as "no records yet"
        trends = payload.get('trends') or []
        rows = [
            [parser(value) for value, parser in zip(record.split(','), field_parsers)]
            for record in trends
        ]
        return pd.DataFrame(rows, columns=field_names)

    def get_mul_sec_rt_quote(self, codes, fields=None, asset_type='I', thrd_num=8):
        '''
        the method is used when multiple codes are provided.
        it utilizes a thread pool to acquire multiple codes at the same time.

        Parameters
        ----------
        codes : Iterable
            codes of sec/index to be exported (a single str is accepted).
        fields : Sequence
            a sequence of fields (either field name or field code) to be
            exported.
        asset_type : str
            the same as asset_type in self.get_sec_id.
        thrd_num : int
            maximum number of threads used at the same time to export data.

        Returns
        -------
        pandas.DataFrame
            the per-code frames concatenated in the order of ``codes``, with
            an extra 'code' column; an empty frame with those columns when
            ``codes`` is empty.

        Raises
        ------
        ValueError
            if thrd_num is not positive. Any exception raised while fetching
            a code is re-raised here.
        '''
        if isinstance(codes, str):
            codes = [codes]
        codes = list(codes)
        if not codes:
            # pd.concat([]) would raise "No objects to concatenate"
            return pd.DataFrame(columns=self.get_field_names(fields) + ['code'])

        def fetch(code):
            df_data = self.get_sec_rt_quote(code, fields, asset_type)
            df_data['code'] = code
            return df_data

        # executor.map keeps input order and re-raises worker exceptions. The
        # previous Queue design deadlocked on que.join() when a worker raised,
        # because task_done() was never called for that item.
        with ThreadPoolExecutor(max_workers=thrd_num) as executor:
            frames = list(executor.map(fetch, codes))
        return pd.concat(frames, ignore_index=True)
        

if __name__ == '__main__':
    # Manual smoke test: fetch intraday trends for a few equities
    # concurrently and report how long it took.
    import time

    # perf_counter is monotonic, unlike time.time(), so the elapsed time
    # cannot be skewed by system clock adjustments.
    start_time = time.perf_counter()
    exporter = EastmoneyRealtime({
        'f52': 'f52',
        'f54': 'f54',
        'f55': 'f55',
    })
    df_index_data = exporter.get_mul_sec_rt_quote(
        ['600000', '000001', '600519', '601888'], None, 'E', thrd_num=4)
    time_spent = time.perf_counter() - start_time
    print(f'Fetched {len(df_index_data)} rows.')
    print(f'The time elapsed is {time_spent:.5f}.')
